{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0XQ6NsIuDtgr"
   },
   "source": [
    "# Implementing Self-Attention\n",
    "\n",
    "[Article link](https://zhuanlan.zhihu.com/p/347492368)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U76qWlrbOmx7"
   },
   "source": [
    "![](https://pic2.zhimg.com/80/v2-b900fb952a100acd7dd8cd65ebd8bd61_1440w.gif)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "atUYzU3TSD9z"
   },
   "source": [
    "### Step 0: What is self-attention?\n",
    "\n",
    "For background, see [One Article Is All You Need for Transformers (Part 1): Self-Attention](https://zhuanlan.zhihu.com/p/345680792)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SDMmHAaSTE6P"
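   },
   "source": [
    "Next, we explain and implement the whole self-attention process, step by step:\n",
    "1. Prepare the inputs\n",
    "2. Initialize the weights\n",
    "3. Derive the key, query, and value\n",
    "4. Compute the attention scores for input 1\n",
    "5. Compute the softmax\n",
    "6. Multiply the values by the scores\n",
    "7. Sum the weighted values to get output 1\n",
    "8. Repeat steps 4-7 to get output 2 and output 3"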
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "u1UxPJlHBVmS",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ENdzUZqSBsiB"
   },
   "source": [
    "### Step 1: Prepare the inputs\n",
    "\n",
    "To keep things simple, we use 3 inputs, each of which is a 4-dimensional vector.\n",
    "\n",
    "![](https://pic1.zhimg.com/80/v2-071dfa785114e675be5dff040f4626e6_1440w.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 70
    },
    "id": "jKYrJsljBhnv",
    "outputId": "7b865905-2151-4a6a-a899-5439aa429af4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 0., 1., 0.],\n",
       "        [0., 2., 0., 2.],\n",
       "        [1., 1., 1., 1.]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = [\n",
    "  [1, 0, 1, 0], # Input 1\n",
    "  [0, 2, 0, 2], # Input 2\n",
    "  [1, 1, 1, 1]  # Input 3\n",
    " ]\n",
    "x = torch.tensor(x, dtype=torch.float32)\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DZ96EoE1Bvat"
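   },
   "source": [
    "### Step 2: Initialize the weights\n",
    "\n",
    "![](https://pic1.zhimg.com/80/v2-d6b0f7707e9af39d40361ff3088cea2c_1440w.gif)\n",
    "\n",
    "Every input has three representations: a key (orange), a query (red), and a value (purple). In this example we want each representation to be a 3-dimensional vector. Since the inputs are 4-dimensional, each weight matrix has shape $4\\times3$.\n",
    "\n",
    "Note: *As we will see later, the dimension of the value is also the dimension of the output.*\n",
    "\n",
    "To obtain these representations, each input (green) is multiplied by the weight matrices for the key, the query, and the value. In our example, we initialize these weights as follows."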
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 284
    },
    "id": "jUTNr15JBkSG",
    "outputId": "baa4c379-6174-4990-8cd2-51191e904550"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights for key: \n",
      " tensor([[0., 0., 1.],\n",
      "        [1., 1., 0.],\n",
      "        [0., 1., 0.],\n",
      "        [1., 1., 0.]])\n",
      "Weights for query: \n",
      " tensor([[1., 0., 1.],\n",
      "        [1., 0., 0.],\n",
      "        [0., 0., 1.],\n",
      "        [0., 1., 1.]])\n",
      "Weights for value: \n",
      " tensor([[0., 2., 0.],\n",
      "        [0., 3., 0.],\n",
      "        [1., 0., 3.],\n",
      "        [1., 1., 0.]])\n"
     ]
    }
   ],
   "source": [
    "w_key = [\n",
    "  [0, 0, 1],\n",
    "  [1, 1, 0],\n",
    "  [0, 1, 0],\n",
    "  [1, 1, 0]\n",
    "]\n",
    "w_query = [\n",
    "  [1, 0, 1],\n",
    "  [1, 0, 0],\n",
    "  [0, 0, 1],\n",
    "  [0, 1, 1]\n",
    "]\n",
    "w_value = [\n",
    "  [0, 2, 0],\n",
    "  [0, 3, 0],\n",
    "  [1, 0, 3],\n",
    "  [1, 1, 0]\n",
    "]\n",
    "w_key = torch.tensor(w_key, dtype=torch.float32)\n",
    "w_query = torch.tensor(w_query, dtype=torch.float32)\n",
    "w_value = torch.tensor(w_value, dtype=torch.float32)\n",
    "\n",
    "print(\"Weights for key: \\n\", w_key)\n",
    "print(\"Weights for query: \\n\", w_query)\n",
    "print(\"Weights for value: \\n\", w_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8pr9XZF9X_Ed"
   },
   "source": [
    "Note: *In a real neural network these weights are usually small numbers, initialized by random sampling with a scheme such as Gaussian, Xavier, or Kaiming initialization.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UxGT5awVB1Xw",
    "pycharm": {
     "name": "#%% md\n"
    }
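   },
   "source": [
    "### Step 3: Derive the key, query, and value\n",
    "\n",
    "![](https://pic1.zhimg.com/80/v2-92c80cc4a2741e48f678366316f2f57c_1440w.gif)\n",
    "\n",
    "Now that we have the three weight matrices, we can compute the actual key, query, and value representations of every input."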
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VQwhDIi7aGXp"
   },
   "source": [
    "The key representations are:\n",
    "```\n",
    "               [0, 0, 1]\n",
    "[1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]\n",
    "[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]\n",
    "[1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qi0EblXTamFz"
   },
   "source": [
    "The value representations are:\n",
    "```\n",
    "               [0, 2, 0]\n",
    "[1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3] \n",
    "[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]\n",
    "[1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]\n",
    "```\n",
    "![](https://pic2.zhimg.com/80/v2-27850c32ead506551de9c088e35e4a67_1440w.gif)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GTp2izu1bLNq"
   },
   "source": [
    "The query representations are:\n",
    "```\n",
    "               [1, 0, 1]\n",
    "[1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]\n",
    "[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]\n",
    "[1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]\n",
    "```\n",
    "![](https://pic1.zhimg.com/80/v2-ee0e36af3dbade8150642f57260cfc4a_1440w.gif)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qegb9M0KbnRK"
   },
   "source": [
    "Note: *In practice, a bias vector may be added to the result of the matrix multiplication.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 230
    },
    "id": "rv2NXynOB7oG",
    "outputId": "a2656b52-4b1d-4726-9d42-522f941b3126",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keys: \n",
      " tensor([[0., 1., 1.],\n",
      "        [4., 4., 0.],\n",
      "        [2., 3., 1.]])\n",
      "Querys: \n",
      " tensor([[1., 0., 2.],\n",
      "        [2., 2., 2.],\n",
      "        [2., 1., 3.]])\n",
      "Values: \n",
      " tensor([[1., 2., 3.],\n",
      "        [2., 8., 0.],\n",
      "        [2., 6., 3.]])\n"
     ]
    }
   ],
   "source": [
    "keys = x @ w_key\n",
    "querys = x @ w_query\n",
    "values = x @ w_value\n",
    "\n",
    "print(\"Keys: \\n\", keys)\n",
    "# tensor([[0., 1., 1.],\n",
    "#         [4., 4., 0.],\n",
    "#         [2., 3., 1.]])\n",
    "\n",
    "print(\"Querys: \\n\", querys)\n",
    "# tensor([[1., 0., 2.],\n",
    "#         [2., 2., 2.],\n",
    "#         [2., 1., 3.]])\n",
    "print(\"Values: \\n\", values)\n",
    "# tensor([[1., 2., 3.],\n",
    "#         [2., 8., 0.],\n",
    "#         [2., 6., 3.]])"
   ]
  },
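  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact summary of the three products above, here is a sketch in matrix notation. The symbols $X$, $W^K$, $W^Q$, $W^V$, $K$, $Q$, and $V$ are labels introduced here for the notebook variables `x`, `w_key`, `w_query`, `w_value`, `keys`, `querys`, and `values`; they are an assumption of this note, not notation used elsewhere in the walkthrough.\n",
    "\n",
    "$$\n",
    "K = X W^K, \\qquad Q = X W^Q, \\qquad V = X W^V\n",
    "$$\n",
    "\n",
    "With $X \\in \\mathbb{R}^{3\\times4}$ (3 inputs of dimension 4) and each weight matrix in $\\mathbb{R}^{4\\times3}$, each of $K$, $Q$, and $V$ has shape $3\\times3$, and row $i$ holds the representation of input $i$."
   ]
  },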
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3pmf0OQhCnD8"
   },
   "source": [
    "### Step 4: Compute the attention scores for input 1\n",
    "![](https://pic2.zhimg.com/80/v2-d8a8cb449ff5edc1f47b262aa1e472af_1440w.gif)\n",
    "\n",
    "To obtain the attention scores for input 1, we take the dot product of its query (red) with every key (orange), including its own key. Since there are 3 key representations (one per input), we obtain 3 attention scores (blue).\n",
    "\n",
    "```\n",
    "            [0, 4, 2]\n",
    "[1, 0, 2] x [1, 4, 3] = [2, 4, 4]\n",
    "            [1, 0, 1]\n",
    "```\n",
    "Note that this example uses only the query of input 1. Later we repeat the same operation for the queries of the other inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 70
    },
    "id": "6GDhKEl0Cokw",
    "outputId": "c91356df-202c-4816-e98d-eefd1e1031d3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 2.,  4.,  4.],\n",
      "        [ 4., 16., 12.],\n",
      "        [ 4., 12., 10.]])\n"
     ]
    }
   ],
   "source": [
    "attn_scores = querys @ keys.T\n",
    "print(attn_scores)\n",
    "\n",
    "# tensor([[ 2.,  4.,  4.],  # attention scores from Query 1\n",
    "#         [ 4., 16., 12.],  # attention scores from Query 2\n",
    "#         [ 4., 12., 10.]]) # attention scores from Query 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bO3NmnbvCxpX"
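   },
   "source": [
    "### Step 5: Compute the softmax\n",
    "![](https://pic1.zhimg.com/80/v2-2115ca2a23e202c48b986d02c0e18158_1440w.gif)\n",
    "\n",
    "Apply softmax to the attention scores (blue). The result below is rounded to one decimal place for readability; the exact values are approximately `[0.063, 0.468, 0.468]`.\n",
    "```\n",
    "softmax([2, 4, 4]) ≈ [0.0, 0.5, 0.5]\n",
    "```"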
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 124
    },
    "id": "PDNzdZHVC1ys",
    "outputId": "c528a7be-5c26-46a9-8fdb-1f2b029b6b93"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],\n",
      "        [6.0337e-06, 9.8201e-01, 1.7986e-02],\n",
      "        [2.9539e-04, 8.8054e-01, 1.1917e-01]])\n",
      "tensor([[0.0000, 0.5000, 0.5000],\n",
      "        [0.0000, 1.0000, 0.0000],\n",
      "        [0.0000, 0.9000, 0.1000]])\n"
     ]
    }
   ],
   "source": [
    "from torch.nn.functional import softmax\n",
    "\n",
    "attn_scores_softmax = softmax(attn_scores, dim=-1)\n",
    "print(attn_scores_softmax)\n",
    "# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],\n",
    "#         [6.0337e-06, 9.8201e-01, 1.7986e-02],\n",
    "#         [2.9539e-04, 8.8054e-01, 1.1917e-01]])\n",
    "\n",
    "# For readability, approximate the above as follows\n",
    "attn_scores_softmax = [\n",
    "  [0.0, 0.5, 0.5],\n",
    "  [0.0, 1.0, 0.0],\n",
    "  [0.0, 0.9, 0.1]\n",
    "]\n",
    "attn_scores_softmax = torch.tensor(attn_scores_softmax)\n",
    "print(attn_scores_softmax)"
   ]
  },
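  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the softmax step concrete, the first row can be worked out by hand. This is a sketch of the arithmetic for the scores `[2, 4, 4]` of query 1; the symbol $s$ for the score vector is a label introduced here for illustration.\n",
    "\n",
    "$$\n",
    "\\mathrm{softmax}(s)_j = \\frac{e^{s_j}}{\\sum_k e^{s_k}}, \\qquad s = (2, 4, 4)\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\mathrm{softmax}(s)_1 = \\frac{e^{2}}{e^{2} + 2e^{4}} = \\frac{1}{1 + 2e^{2}} \\approx 0.0634, \\qquad \\mathrm{softmax}(s)_2 = \\mathrm{softmax}(s)_3 = \\frac{e^{4}}{e^{2} + 2e^{4}} = \\frac{e^{2}}{1 + 2e^{2}} \\approx 0.4683\n",
    "$$\n",
    "\n",
    "These values agree with the first row printed by `softmax(attn_scores, dim=-1)`, which the walkthrough then rounds to `[0.0, 0.5, 0.5]`."
   ]
  },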
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iBe71nseDBhb"
   },
   "source": [
    "### Step 6: Multiply the values by the scores\n",
    "![](https://pic2.zhimg.com/80/v2-5c8017097df554dbaaba2b6acad43a11_1440w.gif)\n",
    "\n",
    "Multiply each softmaxed attention score (blue) by its corresponding value (purple). This gives 3 weighted values (yellow).\n",
    "```\n",
    "1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]\n",
    "2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]\n",
    "3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]\n",
    "``` "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 212
    },
    "id": "tNnx-Fx5DFDi",
    "outputId": "abc7a8ec-f964-483a-9bfb-2848f0e8e592"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000]],\n",
      "\n",
      "        [[1.0000, 4.0000, 0.0000],\n",
      "         [2.0000, 8.0000, 0.0000],\n",
      "         [1.8000, 7.2000, 0.0000]],\n",
      "\n",
      "        [[1.0000, 3.0000, 1.5000],\n",
      "         [0.0000, 0.0000, 0.0000],\n",
      "         [0.2000, 0.6000, 0.3000]]])\n"
     ]
    }
   ],
   "source": [
    "# weighted_values[j, i] = (softmax score of query i for key j) * values[j]\n",
    "weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]\n",
    "print(weighted_values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gU6w0U9ADQIc"
   },
   "source": [
    "### 第7步: 给value加权求和获取output\n",
    "![](https://pic1.zhimg.com/80/v2-28f94e09aba39207017122b271a9858d_1440w.gif)\n",
    "\n",
    "把所有的weighted values（黄色）进行element-wise的相加。\n",
    "```\n",
    "  [0.0, 0.0, 0.0]\n",
    "+ [1.0, 4.0, 0.0]\n",
    "+ [1.0, 3.0, 1.5]\n",
    "-----------------\n",
    "= [2.0, 7.0, 1.5]\n",
    "```\n",
    "\n",
    "得到结果向量[2.0, 7.0, 1.5]（深绿色）就是ouput1的和其他key交互的query representation。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P3yNYDUEgAos",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### Step 8: Repeat steps 4-7 to get output 2 and output 3\n",
    "![](https://pic1.zhimg.com/80/v2-28f94e09aba39207017122b271a9858d_1440w.gif)\n",
    "\n",
    "Now that output 1 is complete, we repeat steps 4 to 7 for input 2 and input 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 70
    },
    "id": "R6excNSUDRRj",
    "outputId": "e5161fbe-05a5-41d2-da1e-5951ce8b1674",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2.0000, 7.0000, 1.5000],\n",
      "        [2.0000, 8.0000, 0.0000],\n",
      "        [2.0000, 7.8000, 0.3000]])\n"
     ]
    }
   ],
   "source": [
    "outputs = weighted_values.sum(dim=0)\n",
    "print(outputs)\n",
    "\n",
    "# tensor([[2.0000, 7.0000, 1.5000],  # Output 1\n",
    "#         [2.0000, 8.0000, 0.0000],  # Output 2\n",
    "#         [2.0000, 7.8000, 0.3000]]) # Output 3"
   ]
  },
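  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps 4 to 7 can also be written as a single matrix expression. The sketch below uses $Q$, $K$, and $V$ as labels (introduced here, not part of the original walkthrough) for `querys`, `keys`, and `values`, with softmax applied to each row.\n",
    "\n",
    "$$\n",
    "\\mathrm{outputs} = \\mathrm{softmax}\\left(Q K^{\\top}\\right) V\n",
    "$$\n",
    "\n",
    "Row $i$ of the result is the weighted sum of the rows of $V$, with the softmaxed scores of query $i$ as weights, so `attn_scores_softmax @ values` should give the same tensor as the broadcast-and-sum above. This walkthrough applies softmax to the raw scores. The scaled dot-product attention of the original Transformer paper additionally divides $Q K^{\\top}$ by $\\sqrt{d_k}$, where $d_k$ is the key dimension, before the softmax."
   ]
  },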
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oavQirdbhAK7"
   },
   "source": [
    "### Bonus: TensorFlow 2 implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "575q0u_ahP-6",
    "outputId": "867a4e88-2223-41e4-ccd5-dbc47f580c83"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88
    },
    "id": "0vjwwEKMhqmZ",
    "outputId": "56e5ed58-e100-434d-a8b2-00325bfc0d40",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[1. 0. 1. 0.]\n",
      " [0. 2. 0. 2.]\n",
      " [1. 1. 1. 1.]], shape=(3, 4), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "x = [\n",
    "  [1, 0, 1, 0], # Input 1\n",
    "  [0, 2, 0, 2], # Input 2\n",
    "  [1, 1, 1, 1]  # Input 3\n",
    " ]\n",
    "\n",
    "x = tf.convert_to_tensor(x, dtype=tf.float32)\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 337
    },
    "id": "TN-pri7rhwJ-",
    "outputId": "aa8b1395-80a3-41e1-b544-beb06ce65a96"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weights for key: \n",
      " tf.Tensor(\n",
      "[[0. 0. 1.]\n",
      " [1. 1. 0.]\n",
      " [0. 1. 0.]\n",
      " [1. 1. 0.]], shape=(4, 3), dtype=float32)\n",
      "Weights for query: \n",
      " tf.Tensor(\n",
      "[[1. 0. 1.]\n",
      " [1. 0. 0.]\n",
      " [0. 0. 1.]\n",
      " [0. 1. 1.]], shape=(4, 3), dtype=float32)\n",
      "Weights for value: \n",
      " tf.Tensor(\n",
      "[[0. 2. 0.]\n",
      " [0. 3. 0.]\n",
      " [1. 0. 3.]\n",
      " [1. 1. 0.]], shape=(4, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "w_key = [\n",
    "  [0, 0, 1],\n",
    "  [1, 1, 0],\n",
    "  [0, 1, 0],\n",
    "  [1, 1, 0]\n",
    "]\n",
    "w_query = [\n",
    "  [1, 0, 1],\n",
    "  [1, 0, 0],\n",
    "  [0, 0, 1],\n",
    "  [0, 1, 1]\n",
    "]\n",
    "w_value = [\n",
    "  [0, 2, 0],\n",
    "  [0, 3, 0],\n",
    "  [1, 0, 3],\n",
    "  [1, 1, 0]\n",
    "]\n",
    "w_key = tf.convert_to_tensor(w_key, dtype=tf.float32)\n",
    "w_query = tf.convert_to_tensor(w_query, dtype=tf.float32)\n",
    "w_value = tf.convert_to_tensor(w_value, dtype=tf.float32)\n",
    "print(\"Weights for key: \\n\", w_key)\n",
    "print(\"Weights for query: \\n\", w_query)\n",
    "print(\"Weights for value: \\n\", w_value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 230
    },
    "id": "Jp2DP46Sh19r",
    "outputId": "5c1befaf-e096-454c-8402-885f049752e0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[0. 1. 1.]\n",
      " [4. 4. 0.]\n",
      " [2. 3. 1.]], shape=(3, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[1. 0. 2.]\n",
      " [2. 2. 2.]\n",
      " [2. 1. 3.]], shape=(3, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[1. 2. 3.]\n",
      " [2. 8. 0.]\n",
      " [2. 6. 3.]], shape=(3, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "keys = tf.matmul(x, w_key)\n",
    "querys = tf.matmul(x, w_query)\n",
    "values = tf.matmul(x, w_value)\n",
    "print(keys)\n",
    "print(querys)\n",
    "print(values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88
    },
    "id": "tLJDo_bFigkm",
    "outputId": "b5d8e02d-9531-49c8-a587-7a6e0b6f884d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[ 2.  4.  4.]\n",
      " [ 4. 16. 12.]\n",
      " [ 4. 12. 10.]], shape=(3, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "attn_scores = tf.matmul(querys, keys, transpose_b=True)\n",
    "print(attn_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 159
    },
    "id": "8QY858MEiibV",
    "outputId": "2e84f48b-a4ed-4116-8655-21cbb9de8358"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[6.3378938e-02 4.6831051e-01 4.6831051e-01]\n",
      " [6.0336647e-06 9.8200780e-01 1.7986100e-02]\n",
      " [2.9538720e-04 8.8053685e-01 1.1916770e-01]], shape=(3, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.  0.5 0.5]\n",
      " [0.  1.  0. ]\n",
      " [0.  0.9 0.1]], shape=(3, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "attn_scores_softmax = tf.nn.softmax(attn_scores, axis=-1)\n",
    "print(attn_scores_softmax)\n",
    "\n",
    "# For readability, approximate the above as follows\n",
    "attn_scores_softmax = [\n",
    "  [0.0, 0.5, 0.5],\n",
    "  [0.0, 1.0, 0.0],\n",
    "  [0.0, 0.9, 0.1]\n",
    "]\n",
    "attn_scores_softmax = tf.convert_to_tensor(attn_scores_softmax)\n",
    "print(attn_scores_softmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 230
    },
    "id": "TOJMfkFpi0KQ",
    "outputId": "8de18989-50d7-4534-cf5c-2711c66d17ce"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[[0.  0.  0. ]\n",
      "  [0.  0.  0. ]\n",
      "  [0.  0.  0. ]]\n",
      "\n",
      " [[1.  4.  0. ]\n",
      "  [2.  8.  0. ]\n",
      "  [1.8 7.2 0. ]]\n",
      "\n",
      " [[1.  3.  1.5]\n",
      "  [0.  0.  0. ]\n",
      "  [0.2 0.6 0.3]]], shape=(3, 3, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "weighted_values = values[:,None] * tf.transpose(attn_scores_softmax)[:,:,None]\n",
    "print(weighted_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88
    },
    "id": "jan_cyy7i-s7",
    "outputId": "09b1406f-3a08-47e2-8dee-d4d6334ef1de"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[2.        7.        1.5      ]\n",
      " [2.        8.        0.       ]\n",
      " [2.        7.7999997 0.3      ]], shape=(3, 3), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "outputs = tf.reduce_sum(weighted_values, axis=0)  # sum over the key/value axis\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "basic_self-attention .ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}